#!/usr/bin/env python3
"""
Copyright 2018 Google Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import base64
from contextlib import closing
import json
import netrc
import os
import os.path
import platform
import re
import shutil
import subprocess
import sys
import tempfile
import time

try:
    from urllib.parse import urlparse
    from urllib.request import urlopen, Request
except ImportError:
    # Python 2.x compatibility hack.
    # http://python-future.org/compatible_idioms.html?highlight=urllib#urllib-module
    from urlparse import urlparse
    from urllib2 import urlopen, Request


# Version and LooseVersion are from setuptools ( https://github.com/pypa/setuputils )
# and MIT licensed
class Version:
    """Abstract base for version-number classes.

    Supplies a constructor that hands a non-empty version string to the
    subclass's ``parse`` method, a ``__repr__``, and the rich comparison
    operators. Every comparison is delegated to the subclass's ``_cmp``,
    which returns a negative, zero or positive number, or ``NotImplemented``
    for operands it cannot handle.
    """

    def __init__(self, vstring=None):
        if not vstring:
            return
        self.parse(vstring)

    def __repr__(self):
        return "{} ('{}')".format(self.__class__.__name__, str(self))

    def _compare_via_cmp(self, other, predicate):
        # Hand NotImplemented back untouched so Python can try the reflected operation.
        outcome = self._cmp(other)
        if outcome is NotImplemented:
            return outcome
        return predicate(outcome)

    def __eq__(self, other):
        return self._compare_via_cmp(other, lambda c: c == 0)

    def __lt__(self, other):
        return self._compare_via_cmp(other, lambda c: c < 0)

    def __le__(self, other):
        return self._compare_via_cmp(other, lambda c: c <= 0)

    def __gt__(self, other):
        return self._compare_via_cmp(other, lambda c: c > 0)

    def __ge__(self, other):
        return self._compare_via_cmp(other, lambda c: c >= 0)

class LooseVersion(Version):
    """Permissive version number: every string is accepted.

    The string is split into runs of digits, runs of lowercase ASCII letters
    and whatever lies between them; "." separators are discarded. Pieces that
    ``int()`` accepts are stored as integers, the rest as strings, and
    versions compare by comparing those lists. Examples of accepted versions:
    1.5.1, 1.5.2b2, 161, 3.10a, 8.02, 3.4j, 1996.07.12, 3.2.pl0, 3.1.1.6,
    2g6, 11g, 0.960923, 2.2beta29, 1.13++, 5.5.kw, 2.0b1pl0.

    No string is rejected. The resulting order is simple and predictable,
    but it is not always what a human would expect.

    NOTE(review): on Python 3, comparing a numeric piece against a string
    piece at the same position raises TypeError (list comparison semantics).
    """

    component_re = re.compile(r'(\d+ | [a-z]+ | \.)', re.VERBOSE)

    def parse(self, vstring):
        # The parsed list cannot reliably be turned back into the original
        # text, so the string itself is kept for __str__.
        self.vstring = vstring
        self.version = [
            self._as_int_if_possible(piece)
            for piece in self.component_re.split(vstring)
            if piece and piece != '.'
        ]

    @staticmethod
    def _as_int_if_possible(piece):
        # Numeric pieces compare numerically; everything else stays a string.
        try:
            return int(piece)
        except ValueError:
            return piece

    def __str__(self):
        return self.vstring

    def __repr__(self):
        return "LooseVersion ('%s')" % str(self)

    def _cmp(self, other):
        if isinstance(other, str):
            other = LooseVersion(other)
        if not isinstance(other, LooseVersion):
            return NotImplemented

        mine, theirs = self.version, other.version
        if mine == theirs:
            return 0
        return -1 if mine < theirs else 1


# One hour in seconds; the freshness window for the cached releases.json.
ONE_HOUR = 1 * 60 * 60

# Matches the version labels "latest" and "latest-N"; N is captured as "offset".
LATEST_PATTERN = re.compile(r"latest(-(?P<offset>\d+))?$")

# Base URL of the text files that hold a "last green" commit hash; one of the
# LAST_GREEN_COMMIT_PATH_SUFFIXES values is appended to it.
LAST_GREEN_COMMIT_BASE_PATH = (
    "https://storage.googleapis.com/bazel-untrusted-builds/last_green_commit/"
)

# Maps the special version labels "last_green" / "last_downstream_green" to the
# path (relative to LAST_GREEN_COMMIT_BASE_PATH) that stores the commit to use.
LAST_GREEN_COMMIT_PATH_SUFFIXES = {
    "last_green": "github.com/bazelbuild/bazel.git/bazel-bazel",
    "last_downstream_green": "downstream_pipeline",
}

# Download URL of an unreleased Bazel binary built at a specific commit.
BAZEL_GCS_PATH_PATTERN = (
    "https://storage.googleapis.com/bazel-builds/artifacts/{platform}/{commit}/bazel"
)

# Maps platform.system().lower() to the {platform} part of BAZEL_GCS_PATH_PATTERN.
SUPPORTED_PLATFORMS = {"linux": "ubuntu1404", "windows": "windows", "darwin": "macos"}

# Workspace-relative path of an optional wrapper script to delegate to.
TOOLS_BAZEL_PATH = "./tools/bazel"

# Environment variable that tells the tools/bazel wrapper where the real binary is.
BAZEL_REAL = "BAZEL_REAL"

# Fork whose binaries are downloaded; also the name of the downloads subdirectory.
BAZEL_UPSTREAM = "bazelbuild"


def decide_which_bazel_version_to_use():
    """Return the Bazel version label requested for this invocation.

    Sources are consulted in this order, and the first non-empty value wins:
    - the env var "USE_BAZEL_VERSION" (surrounding whitespace stripped),
    - the contents of workspace_root/.bazelversion (whitespace stripped),
    - otherwise the fallback label "latest".

    The returned string is not validated here. It may be a release number
    or a label such as "latest-N" that is resolved later.
    """
    # Planned but not implemented yet (TODO):
    # - env var "USE_NIGHTLY_BAZEL" or "USE_BAZEL_NIGHTLY" -> latest nightly.
    # - env var "USE_CANARY_BAZEL" or "USE_BAZEL_CANARY" -> latest rc.
    # - the file workspace_root/tools/bazel exists -> that version.
    # - workspace_root/WORKSPACE contains a version -> that version.

    # An empty value is treated as "not set" so it does not turn into an
    # empty version string that cannot be downloaded.
    env_version = os.environ.get("USE_BAZEL_VERSION", "").strip()
    if env_version:
        return env_version

    workspace_root = find_workspace_root()
    if workspace_root:
        bazelversion_path = os.path.join(workspace_root, ".bazelversion")
        if os.path.isfile(bazelversion_path):
            with open(bazelversion_path, "r") as f:
                file_version = f.read().strip()
            if file_version:
                return file_version

    return "latest"


def find_workspace_root(root=None):
    """Return the nearest directory at or above ``root`` that is a workspace root.

    A directory counts as a workspace root when it contains one of the
    regular files MODULE.bazel, REPO.bazel, WORKSPACE.bazel or WORKSPACE.
    These are the markers recent Bazel versions use. A directory that merely
    happens to be named like a marker does not count.

    Args:
        root: directory to start from; defaults to the current working directory.
    Returns:
        The workspace root path, or None when no ancestor has a marker file.
    """
    boundary_files = ("MODULE.bazel", "REPO.bazel", "WORKSPACE.bazel", "WORKSPACE")
    current = os.getcwd() if root is None else root
    while True:
        if any(os.path.isfile(os.path.join(current, name)) for name in boundary_files):
            return current
        parent = os.path.dirname(current)
        # dirname() is a fixed point at the filesystem root (or "" for relative paths).
        if parent == current:
            return None
        current = parent


def resolve_version_label_to_number_or_commit(bazelisk_directory, version):
    """Turn a version label into a concrete Bazel release or commit.

    Args:
        bazelisk_directory: string; directory where Bazelisk may keep cached
            data (used for the release history).
        version: string; the label to resolve. "last_green" and
            "last_downstream_green" map to a commit; "latest" / "latest-N"
            map to a release from the history; anything else is returned
            unchanged.
    Returns:
        A (string, bool) pair: the release number (or commit hash) and
        whether it is a commit.
    Raises:
        Exception: if the label contains "latest" but is not of the form
            "latest" or "latest-N".
    """
    if version in LAST_GREEN_COMMIT_PATH_SUFFIXES:
        return get_last_green_commit(LAST_GREEN_COMMIT_PATH_SUFFIXES[version]), True

    if "latest" not in version:
        return version, False

    latest_match = LATEST_PATTERN.match(version)
    if latest_match is None:
        raise Exception(
            'Invalid version "{}". In addition to using a version '
            'number such as "0.20.0", you can use values such as '
            '"latest" and "latest-N", with N being a non-negative '
            "integer.".format(version)
        )

    release_history = get_version_history(bazelisk_directory)
    steps_back = int(latest_match.group("offset") or "0")
    return resolve_latest_version(release_history, steps_back), False


def get_last_green_commit(path_suffix):
    """Fetch the commit hash stored at LAST_GREEN_COMMIT_BASE_PATH + path_suffix."""
    commit_file_url = LAST_GREEN_COMMIT_BASE_PATH + path_suffix
    contents = read_remote_text_file(commit_file_url)
    return contents.strip()


def get_releases_json(bazelisk_directory):
    """Return the parsed GitHub release list of bazelbuild/bazel.

    The raw response is cached in ``<bazelisk_directory>/releases.json`` and
    reused while the file's mtime is within ONE_HOUR of now. A cache that
    cannot be decoded or parsed is reported on stderr and refreshed.

    Returns:
        The decoded JSON value from the GitHub releases API. Callers read
        "tag_name" and "prerelease" from its entries. No ordering is imposed
        here.
    """
    releases = os.path.join(bazelisk_directory, "releases.json")

    # Use a cached version if it's fresh enough.
    if os.path.exists(releases) and abs(time.time() - os.path.getmtime(releases)) < ONE_HOUR:
        with open(releases, "rb") as f:
            cached = f.read()
        try:
            return json.loads(cached.decode("utf-8"))
        except ValueError:  # also covers UnicodeDecodeError
            # Diagnostics go to stderr so they never mix with Bazel's stdout.
            sys.stderr.write("WARN: Could not parse cached releases.json.\n")

    # Download and parse BEFORE touching the cache. A failed request or a bad
    # response then leaves the existing file intact and nothing invalid is cached.
    body = read_remote_text_file("https://api.github.com/repos/bazelbuild/bazel/releases")
    releases_data = json.loads(body)
    with open(releases, "wb") as f:
        f.write(body.encode("utf-8"))
    return releases_data


def read_remote_text_file(url):
    """Download ``url`` and return its body decoded as text.

    Uses the charset from the response headers, falling back to ISO-8859-1
    when none is declared.
    """
    with closing(urlopen(url)) as response:
        payload = response.read()
        headers = response.info()
        try:
            charset = headers.get_content_charset("iso-8859-1")
        except AttributeError:
            # Python 2.x compatibility hack
            charset = headers.getparam("charset") or "iso-8859-1"
        return payload.decode(charset)


def get_version_history(bazelisk_directory):
    """Return the tag names of all non-prerelease Bazel releases, highest first.

    Tags are ordered using LooseVersion comparison.
    """
    stable_versions = [
        LooseVersion(release["tag_name"])
        for release in get_releases_json(bazelisk_directory)
        if not release["prerelease"]
    ]
    stable_versions.sort(reverse=True)
    return list(map(str, stable_versions))


def resolve_latest_version(version_history, offset):
    """Return the release ``offset`` steps behind the newest one.

    Args:
        version_history: list of release numbers, newest first.
        offset: how many releases to go back (0 means the newest).
    Raises:
        Exception: if the history has ``offset`` or fewer entries.
    """
    available = len(version_history)
    if offset < available:
        # The history is stored newest-first, so the offset is a direct index.
        return version_history[offset]

    label = "latest-{}".format(offset) if offset else "latest"
    raise Exception(
        'Cannot resolve version "{}": There are only {} Bazel '
        "releases.".format(label, available)
    )


def get_operating_system():
    """Return the lower-cased OS name ("linux", "darwin" or "windows").

    Raises:
        Exception: on any other operating system.
    """
    system_name = platform.system().lower()
    if system_name in ("linux", "darwin", "windows"):
        return system_name
    raise Exception(
        'Unsupported operating system "{}". '
        "Bazel currently only supports Linux, macOS and Windows.".format(system_name)
    )


def determine_executable_filename_suffix():
    """Return the executable file extension for this OS: ".exe" on Windows, "" elsewhere."""
    if get_operating_system() == "windows":
        return ".exe"
    return ""


def determine_bazel_filename(version):
    """Return the release file name of the Bazel binary for this host.

    The name has the form ``<flavor>-<version>-<os>-<arch>[.exe]``. The
    flavor is "bazel_nojdk" when the env var BAZELISK_NOJDK is set to
    anything other than "0", and "bazel" otherwise.

    Raises:
        Exception: if the machine architecture is neither x86_64 nor arm64,
            or the operating system is unsupported.
    """
    machine = normalized_machine_arch_name()
    if machine not in ("x86_64", "arm64"):
        raise Exception(
            'Unsupported machine architecture "{}". '
            "Bazel currently only supports x86_64 and arm64.".format(machine)
        )

    operating_system = get_operating_system()
    filename_suffix = determine_executable_filename_suffix()
    # Any value other than "0" (even an empty string) selects the no-JDK flavor.
    if os.environ.get("BAZELISK_NOJDK", "0") != "0":
        bazel_flavor = "bazel_nojdk"
    else:
        bazel_flavor = "bazel"
    return "{}-{}-{}-{}{}".format(bazel_flavor, version, operating_system, machine, filename_suffix)


def normalized_machine_arch_name():
    """Return platform.machine() lower-cased, mapping "amd64" to "x86_64" and "aarch64" to "arm64"."""
    raw_name = platform.machine().lower()
    aliases = {"amd64": "x86_64", "aarch64": "arm64"}
    return aliases.get(raw_name, raw_name)


def determine_url(version, is_commit, bazel_filename):
    """Return the download URL of the Bazel binary ``bazel_filename``.

    Args:
        version: a release number such as "0.19.1" or "0.20.0rc1", or a
            commit hash when ``is_commit`` is true.
        is_commit: whether ``version`` is a commit of an unreleased build.
        bazel_filename: file name produced by determine_bazel_filename().
    Returns:
        A GCS URL for commits. Otherwise a URL below BAZELISK_BASE_URL (when
        that env var is set) or below https://releases.bazel.build.
    Raises:
        Exception: if ``version`` does not start with a release number.
    """
    if is_commit:
        sys.stderr.write("Using unreleased version at commit {}\n".format(version))
        # No need to validate the platform thanks to determine_bazel_filename().
        return BAZEL_GCS_PATH_PATTERN.format(
            platform=SUPPORTED_PLATFORMS[platform.system().lower()], commit=version
        )

    # Split version into base version and optional additional identifier.
    # Example: '0.19.1' -> ('0.19.1', None), '0.20.0rc1' -> ('0.20.0', 'rc1')
    match = re.match(r"(\d*\.\d*(?:\.\d*)?)(rc\d+)?", version)
    if match is None:
        raise Exception(
            'Invalid version "{}". Expected a release number such as '
            '"0.20.0" or "0.20.0rc1".'.format(version)
        )
    base_version, rc = match.groups()

    if "BAZELISK_BASE_URL" in os.environ:
        # Strip trailing slashes so the joined URL never contains "//".
        base_url = os.environ["BAZELISK_BASE_URL"].rstrip("/")
        return "{}/{}/{}".format(base_url, base_version, bazel_filename)
    return "https://releases.bazel.build/{}/{}/{}".format(
        base_version, rc if rc else "release", bazel_filename
    )


def trim_suffix(string, suffix):
    """Return ``string`` without a trailing ``suffix``; unchanged if it does not end with it."""
    if not string.endswith(suffix):
        return string
    # Explicit end index (not -len(suffix)) so an empty suffix keeps the whole string.
    cut_at = len(string) - len(suffix)
    return string[:cut_at]


def _make_download_request(url):
    """Build the HTTP request for ``url``, adding netrc Basic auth when a custom mirror is used.

    Credentials are only looked up when BAZELISK_BASE_URL is set. They come
    from the default netrc file entry whose machine name equals the URL's
    network location. A missing or unparsable netrc file is not an error.
    """
    # https://github.com/bazelbuild/bazelisk/issues/247
    request = Request(url)
    if "BAZELISK_BASE_URL" not in os.environ:
        return request

    parts = urlparse(url)
    try:
        creds = netrc.netrc().hosts.get(parts.netloc)
    except (IOError, OSError, ValueError, netrc.NetrcParseError):
        # Best effort: without a usable netrc file, download anonymously.
        creds = None
    if creds is not None:
        # netrc host entries are (login, account, password) tuples.
        auth = base64.b64encode(('%s:%s' % (creds[0], creds[2])).encode('ascii'))
        request.add_header("Authorization", "Basic %s" % auth.decode('utf-8'))
    return request


def download_bazel_into_directory(version, is_commit, directory):
    """Make sure the requested Bazel binary exists under ``directory`` and return its path.

    The binary lives at ``<directory>/<bazel filename minus .exe>/bin/bazel[.exe]``.
    An existing file is reused. Otherwise the binary is downloaded to a
    temporary file in the same directory, made executable and then moved into
    place, so a partial download never appears at the final path.

    Args:
        version: release number or commit to download.
        is_commit: whether ``version`` is a commit of an unreleased build.
        directory: root directory of the download cache.
    Returns:
        Path of the Bazel executable.
    """
    bazel_filename = determine_bazel_filename(version)
    url = determine_url(version, is_commit, bazel_filename)

    filename_suffix = determine_executable_filename_suffix()
    bazel_directory_name = trim_suffix(bazel_filename, filename_suffix)
    destination_dir = os.path.join(directory, bazel_directory_name, "bin")
    maybe_makedirs(destination_dir)

    destination_path = os.path.join(destination_dir, "bazel" + filename_suffix)
    if os.path.exists(destination_path):
        return destination_path

    sys.stderr.write("Downloading {}...\n".format(url))
    temp_file = tempfile.NamedTemporaryFile(prefix="bazelisk", dir=destination_dir, delete=False)
    published = False
    try:
        with temp_file:
            with closing(urlopen(_make_download_request(url))) as response:
                shutil.copyfileobj(response, temp_file)
            temp_file.flush()
            os.fsync(temp_file.fileno())
        # Set permissions before the file becomes visible at its final path.
        os.chmod(temp_file.name, 0o755)
        # os.replace (Python 3.3+) overwrites an existing destination on every
        # platform; os.rename fails on Windows when the destination exists.
        replace = getattr(os, "replace", os.rename)
        try:
            replace(temp_file.name, destination_path)
        except OSError:
            # Another bazelisk process may have installed the same binary
            # concurrently (and on Windows a running binary cannot be replaced).
            if not os.path.exists(destination_path):
                raise
        else:
            published = True
    finally:
        if not published:
            # Don't leave partial or redundant downloads behind.
            try:
                os.remove(temp_file.name)
            except OSError:
                pass

    return destination_path


def get_bazelisk_directory():
    """Return the directory where Bazelisk stores its cache.

    BAZELISK_HOME wins when set to a non-empty value. Otherwise a per-OS
    cache location is used:
    - Windows: %LocalAppData%/bazelisk
    - macOS:   $HOME/Library/Caches/bazelisk
    - Linux:   $XDG_CACHE_HOME/bazelisk, or $HOME/.cache/bazelisk

    Empty environment variables are treated as unset. An empty base would
    otherwise produce a path relative to the current working directory.

    Raises:
        Exception: if the variables needed for this OS are not defined.
    """
    bazelisk_home = os.environ.get("BAZELISK_HOME")
    if bazelisk_home:
        return bazelisk_home

    operating_system = get_operating_system()

    if operating_system == "windows":
        base_dir = os.environ.get("LocalAppData")
        if not base_dir:
            raise Exception("%LocalAppData% is not defined")
    elif operating_system == "darwin":
        home = os.environ.get("HOME")
        if not home:
            raise Exception("$HOME is not defined")
        base_dir = os.path.join(home, "Library/Caches")
    elif operating_system == "linux":
        # Per the XDG Base Directory spec, an empty XDG_CACHE_HOME means "unset".
        base_dir = os.environ.get("XDG_CACHE_HOME")
        if not base_dir:
            home = os.environ.get("HOME")
            if not home:
                raise Exception("neither $XDG_CACHE_HOME nor $HOME are defined")
            base_dir = os.path.join(home, ".cache")
    else:
        raise Exception("Unsupported operating system '{}'".format(operating_system))

    return os.path.join(base_dir, "bazelisk")


def maybe_makedirs(path):
    """Create ``path`` and any missing parents; do nothing if it is already a directory.

    Raises:
        OSError: if creation fails and ``path`` is not a directory afterwards
            (for example, because a regular file occupies it).
    """
    try:
        os.makedirs(path)
    except OSError:
        # Another process (or an earlier run) may already have created it.
        if os.path.isdir(path):
            return
        raise


def delegate_tools_bazel(bazel_path):
    """Return the workspace's tools/bazel wrapper if Bazel should delegate to it, else None.

    This matches the delegation behavior of Bazel builds shipped by most
    package managers. The wrapper is used when it exists, is executable, and
    is not this very script. ``bazel_path`` is accepted but not used.
    """
    root = find_workspace_root()
    if not root:
        return None

    wrapper = os.path.join(root, TOOLS_BAZEL_PATH)
    if not (os.path.exists(wrapper) and os.access(wrapper, os.X_OK)):
        return None

    try:
        points_at_self = os.path.samefile(wrapper, __file__)
    except AttributeError:
        # Python 2 on Windows does not support os.path.samefile
        points_at_self = os.path.abspath(wrapper) == os.path.abspath(__file__)
    return None if points_at_self else wrapper


def prepend_directory_to_path(env, directory):
    """Put ``directory`` first on the PATH held in the ``env`` mapping (modified in place).

    A missing or empty PATH is replaced by ``directory`` alone. Joining with
    os.pathsep onto an empty PATH would add an empty entry, which POSIX
    command lookup treats as the current working directory.
    """
    existing = env.get("PATH")
    if existing:
        env["PATH"] = directory + os.pathsep + existing
    else:
        env["PATH"] = directory


def make_bazel_cmd(bazel_path, argv):
    """Describe how to launch Bazel.

    If the workspace has a tools/bazel wrapper (see delegate_tools_bazel),
    the wrapper is launched instead and receives the real binary's path in
    the BAZEL_REAL environment variable. The launched program's directory is
    prepended to PATH.

    Returns:
        A dict with keys 'exec' (program path), 'args' (``argv`` as given)
        and 'env' (a modified copy of os.environ).
    """
    env = dict(os.environ)
    executable = bazel_path

    wrapper = delegate_tools_bazel(bazel_path)
    if wrapper:
        env[BAZEL_REAL] = bazel_path
        executable = wrapper

    prepend_directory_to_path(env, os.path.dirname(executable))
    return {'exec': executable, 'args': argv, 'env': env}


def execute_bazel(bazel_path, argv):
    """Run Bazel (or its tools/bazel wrapper) as a child process and return its exit code."""
    cmd = make_bazel_cmd(bazel_path, argv)
    command_line = [cmd['exec']] + cmd['args']

    # We cannot use close_fds on Windows, so disable it there.
    use_close_fds = os.name != "nt"
    child = subprocess.Popen(command_line, close_fds=use_close_fds, env=cmd['env'])

    while True:
        try:
            return child.wait()
        except KeyboardInterrupt:
            # Bazel will also get the signal and terminate.
            # Keep waiting until it does so.
            continue


def get_bazel_path():
    """Work out which Bazel to run, download it if necessary, and return its path."""
    bazelisk_directory = get_bazelisk_directory()
    maybe_makedirs(bazelisk_directory)

    requested = decide_which_bazel_version_to_use()
    resolved, is_commit = resolve_version_label_to_number_or_commit(bazelisk_directory, requested)

    # TODO: Support other forks just like Go version
    downloads_dir = os.path.join(bazelisk_directory, "downloads", BAZEL_UPSTREAM)
    return download_bazel_into_directory(resolved, is_commit, downloads_dir)


def main(argv=None):
    """Command-line entry point.

    Args:
        argv: full argument vector including the program name; defaults to sys.argv.
    Returns:
        Bazel's exit code. For ``--print_env`` as the first argument, the
        environment Bazel would be started with is printed as KEY=VALUE
        lines instead, and 0 is returned.
    """
    full_argv = sys.argv if argv is None else argv

    bazel_path = get_bazel_path()
    bazel_args = full_argv[1:]

    if not bazel_args or bazel_args[0] != "--print_env":
        return execute_bazel(bazel_path, bazel_args)

    env = make_bazel_cmd(bazel_path, bazel_args)['env']
    for name, value in env.items():
        print('{}={}'.format(name, value))
    return 0


# Script entry point: exit with main()'s return value (Bazel's exit code, or 0 for --print_env).
if __name__ == "__main__":
    sys.exit(main())
